{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Basics: overfitting a MLP on CIFAR10\n",
    "\n",
    "Training loop over CIFAR10 (40,000 train images, 10,000 test images). What happens if you\n",
    "- switch the training to a GPU? Is it faster?\n",
    "- Remove the `ReLU()`? \n",
    "- Increase the learning rate?\n",
    "- Stack more layers? \n",
    "- Perform more epochs?\n",
    "\n",
    "Can you completely overfit the training set (i.e. get 100% accuracy?)\n",
    "\n",
    "This code is highly non-modulable. Create functions for each specific task. \n",
    "(hint: see [this](https://github.com/pytorch/examples/blob/master/mnist/main.py))\n",
    "\n",
    "Your training went well. Good. Why not save the weights of the network (`net.state_dict()`) using `torch.save()`?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "Epoch 0\n",
      "Train loss 0.0381, Train accuracy 20.31%\n",
      "Train loss 0.0298, Train accuracy 33.37%\n",
      "Train loss 0.0283, Train accuracy 36.81%\n",
      "Train loss 0.0275, Train accuracy 38.50%\n",
      "Train loss 0.0268, Train accuracy 39.95%\n",
      "Train loss 0.0264, Train accuracy 40.74%\n",
      "Train loss 0.0261, Train accuracy 41.38%\n",
      "Train loss 0.0259, Train accuracy 41.93%\n",
      "Epoch 1\n",
      "Train loss 0.0253, Train accuracy 53.59%\n",
      "Train loss 0.0231, Train accuracy 49.87%\n",
      "Train loss 0.0230, Train accuracy 49.38%\n",
      "Train loss 0.0229, Train accuracy 49.60%\n",
      "Train loss 0.0226, Train accuracy 49.82%\n",
      "Train loss 0.0225, Train accuracy 49.95%\n",
      "Train loss 0.0225, Train accuracy 49.99%\n",
      "Train loss 0.0224, Train accuracy 50.09%\n",
      "Epoch 2\n",
      "Train loss 0.0230, Train accuracy 59.69%\n",
      "Train loss 0.0214, Train accuracy 53.51%\n",
      "Train loss 0.0213, Train accuracy 53.27%\n",
      "Train loss 0.0212, Train accuracy 53.50%\n",
      "Train loss 0.0210, Train accuracy 53.86%\n",
      "Train loss 0.0209, Train accuracy 53.94%\n",
      "Train loss 0.0209, Train accuracy 53.86%\n",
      "Train loss 0.0209, Train accuracy 53.90%\n",
      "End of training.\n",
      "\n",
      "End of testing. Test accuracy 50.51%\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as t\n",
    "\n",
    "# define network structure \n",
    "net = nn.Sequential(nn.Linear(3 * 32 * 32, 1000), nn.ReLU(), nn.Linear(1000, 10))\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr = 0.01, momentum=0.9)\n",
    "\n",
    "# load data\n",
    "to_tensor =  t.ToTensor()\n",
    "normalize = t.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "flatten =  t.Lambda(lambda x:x.view(-1))\n",
    "\n",
    "transform_list = t.Compose([to_tensor, normalize, flatten])\n",
    "train_set = torchvision.datasets.CIFAR10(root='.', train=True, transform=transform_list, download=True)\n",
    "test_set = torchvision.datasets.CIFAR10(root='.', train=False, transform=transform_list, download=True)\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(train_set, batch_size=64)\n",
    "test_loader = torch.utils.data.DataLoader(test_set, batch_size=64)\n",
    "\n",
    "# === Train === ###\n",
    "net.train()\n",
    "\n",
    "# train loop\n",
    "for epoch in range(3):\n",
    "    train_correct = 0\n",
    "    train_loss = 0\n",
    "    print('Epoch {}'.format(epoch))\n",
    "    \n",
    "    # loop per epoch \n",
    "    for i, (batch, targets) in enumerate(train_loader):\n",
    "\n",
    "        output = net(batch)\n",
    "        loss = criterion(output, targets)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        pred = output.max(1, keepdim=True)[1]\n",
    "        train_correct += pred.eq(targets.view_as(pred)).sum().item()\n",
    "        train_loss += loss\n",
    "\n",
    "        if i % 100 == 10: print('Train loss {:.4f}, Train accuracy {:.2f}%'.format(\n",
    "            train_loss / (i * 64), 100 * train_correct / (i * 64)))\n",
    "        \n",
    "print('End of training.\\n')\n",
    "    \n",
    "# === Test === ###\n",
    "test_correct = 0\n",
    "net.eval()\n",
    "\n",
    "# loop, over whole test set\n",
    "for i, (batch, targets) in enumerate(test_loader):\n",
    "    \n",
    "    output = net(batch)\n",
    "    pred = output.max(1, keepdim=True)[1]\n",
    "    test_correct += pred.eq(targets.view_as(pred)).sum().item()\n",
    "    \n",
    "print('End of testing. Test accuracy {:.2f}%'.format(\n",
    "    100 * test_correct / (len(test_loader) * 64)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, device, train_loader, criterion, optimizer, epoch):\n",
    "    model.train()\n",
    "    print('Epoch {}'.format(epoch))\n",
    "    train_correct = 0\n",
    "    train_loss = 0\n",
    "    train_l = 0\n",
    "    # loop per epoch \n",
    "    for i, (batch, targets) in enumerate(train_loader):\n",
    "        bs = batch.shape[0]\n",
    "        batch = batch.to(device)\n",
    "        targets = targets.to(device)\n",
    "\n",
    "        output = model(batch)\n",
    "        loss = criterion(output, targets)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        pred = output.max(1, keepdim=True)[1]\n",
    "        train_correct += pred.eq(targets.view_as(pred)).sum().item()\n",
    "        train_loss += loss.item()\n",
    "        train_l += bs \n",
    "\n",
    "        if i % 100 == 10: print('Train loss {:.4f}, Train accuracy {:.2f}%'.format(\n",
    "            train_loss / train_l, 100 * train_correct / train_l))\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cpu'\n",
    "net = nn.Sequential(nn.Linear(3 * 32 * 32, 1000), nn.ReLU(), nn.Linear(1000, 10))\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr = 0.01, momentum=0.9)\n",
    "net = net.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1\n",
      "Train loss 0.0347, Train accuracy 19.03%\n",
      "Train loss 0.0295, Train accuracy 32.76%\n",
      "Train loss 0.0282, Train accuracy 36.42%\n",
      "Train loss 0.0274, Train accuracy 38.25%\n",
      "Train loss 0.0267, Train accuracy 39.77%\n",
      "Train loss 0.0263, Train accuracy 40.54%\n",
      "Train loss 0.0261, Train accuracy 41.27%\n",
      "Train loss 0.0258, Train accuracy 41.88%\n",
      "CPU times: user 1min 43s, sys: 394 ms, total: 1min 44s\n",
      "Wall time: 15.1 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "train(net,device,train_loader,criterion,optimizer,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "net = nn.Sequential(nn.Linear(3 * 32 * 32, 1000), nn.ReLU(), nn.Linear(1000, 10))\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr = 0.01, momentum=0.9)\n",
    "net = net.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1\n",
      "Train loss 0.0346, Train accuracy 19.03%\n",
      "Train loss 0.0295, Train accuracy 32.62%\n",
      "Train loss 0.0282, Train accuracy 36.36%\n",
      "Train loss 0.0274, Train accuracy 38.12%\n",
      "Train loss 0.0267, Train accuracy 39.72%\n",
      "Train loss 0.0263, Train accuracy 40.51%\n",
      "Train loss 0.0260, Train accuracy 41.25%\n",
      "Train loss 0.0258, Train accuracy 41.81%\n",
      "CPU times: user 8.97 s, sys: 92.6 ms, total: 9.07 s\n",
      "Wall time: 9.06 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "train(net,device,train_loader,criterion,optimizer,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(model, device, test_loader):\n",
    "    model.eval()\n",
    "    test_correct = 0\n",
    "    test_l =0\n",
    "    for i, (batch, targets) in enumerate(test_loader):\n",
    "        bs = batch.shape[0]\n",
    "        batch = batch.to(device)\n",
    "        targets = targets.to(device)\n",
    "        output = model(batch)\n",
    "        pred = output.max(1, keepdim=True)[1]\n",
    "        test_correct += pred.eq(targets.view_as(pred)).sum().item()\n",
    "        test_l += bs\n",
    "    \n",
    "    print('End of testing. Test accuracy {:.2f}%'.format(\n",
    "    100 * test_correct / test_l))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Train loss 0.0346, Train accuracy 19.18%\n",
      "Train loss 0.0295, Train accuracy 32.87%\n",
      "Train loss 0.0282, Train accuracy 36.43%\n",
      "Train loss 0.0274, Train accuracy 38.24%\n",
      "Train loss 0.0267, Train accuracy 39.85%\n",
      "Train loss 0.0263, Train accuracy 40.63%\n",
      "Train loss 0.0261, Train accuracy 41.27%\n",
      "Train loss 0.0258, Train accuracy 41.88%\n",
      "End of testing. Test accuracy 47.77%\n",
      "Epoch 1\n",
      "Train loss 0.0230, Train accuracy 48.58%\n",
      "Train loss 0.0230, Train accuracy 49.03%\n",
      "Train loss 0.0229, Train accuracy 48.96%\n",
      "Train loss 0.0228, Train accuracy 49.26%\n",
      "Train loss 0.0226, Train accuracy 49.58%\n",
      "Train loss 0.0225, Train accuracy 49.76%\n",
      "Train loss 0.0224, Train accuracy 49.83%\n",
      "Train loss 0.0224, Train accuracy 50.01%\n",
      "End of testing. Test accuracy 49.83%\n",
      "Epoch 2\n",
      "Train loss 0.0210, Train accuracy 52.98%\n",
      "Train loss 0.0212, Train accuracy 52.89%\n",
      "Train loss 0.0212, Train accuracy 52.78%\n",
      "Train loss 0.0211, Train accuracy 53.09%\n",
      "Train loss 0.0209, Train accuracy 53.43%\n",
      "Train loss 0.0208, Train accuracy 53.65%\n",
      "Train loss 0.0208, Train accuracy 53.62%\n",
      "Train loss 0.0208, Train accuracy 53.76%\n",
      "End of testing. Test accuracy 50.77%\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "net = nn.Sequential(nn.Linear(3 * 32 * 32, 1000), nn.ReLU(), nn.Linear(1000, 10))\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr = 0.01, momentum=0.9)\n",
    "net = net.to(device)\n",
    "\n",
    "for i in range(3):\n",
    "    train(net,device,train_loader,criterion,optimizer,i)\n",
    "    test(net,device,test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Train loss 0.0334, Train accuracy 22.30%\n",
      "Train loss 0.0301, Train accuracy 32.38%\n",
      "Train loss 0.0296, Train accuracy 34.12%\n",
      "Train loss 0.0292, Train accuracy 34.63%\n",
      "Train loss 0.0288, Train accuracy 35.68%\n",
      "Train loss 0.0287, Train accuracy 35.93%\n",
      "Train loss 0.0286, Train accuracy 36.26%\n",
      "Train loss 0.0285, Train accuracy 36.36%\n",
      "End of testing. Test accuracy 38.81%\n",
      "Epoch 1\n",
      "Train loss 0.0267, Train accuracy 39.06%\n",
      "Train loss 0.0272, Train accuracy 39.46%\n",
      "Train loss 0.0275, Train accuracy 39.07%\n",
      "Train loss 0.0275, Train accuracy 38.87%\n",
      "Train loss 0.0274, Train accuracy 39.30%\n",
      "Train loss 0.0274, Train accuracy 39.20%\n",
      "Train loss 0.0274, Train accuracy 39.30%\n",
      "Train loss 0.0274, Train accuracy 39.21%\n",
      "End of testing. Test accuracy 39.02%\n",
      "Epoch 2\n",
      "Train loss 0.0264, Train accuracy 41.05%\n",
      "Train loss 0.0269, Train accuracy 40.70%\n",
      "Train loss 0.0272, Train accuracy 39.99%\n",
      "Train loss 0.0272, Train accuracy 39.89%\n",
      "Train loss 0.0271, Train accuracy 40.25%\n",
      "Train loss 0.0271, Train accuracy 40.14%\n",
      "Train loss 0.0271, Train accuracy 40.24%\n",
      "Train loss 0.0271, Train accuracy 40.11%\n",
      "End of testing. Test accuracy 39.23%\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "net = nn.Sequential(nn.Linear(3 * 32 * 32, 1000), nn.Linear(1000, 10))\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr = 0.01, momentum=0.9)\n",
    "net = net.to(device)\n",
    "\n",
    "for i in range(3):\n",
    "    train(net,device,train_loader,criterion,optimizer,i)\n",
    "    test(net,device,test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Train loss 0.0342, Train accuracy 22.73%\n",
      "Train loss 0.2320, Train accuracy 22.58%\n",
      "Train loss 7.4938, Train accuracy 20.96%\n",
      "Train loss 321.0443, Train accuracy 20.28%\n",
      "Train loss 10331.5943, Train accuracy 20.43%\n",
      "Train loss 436773.5373, Train accuracy 20.25%\n",
      "Train loss 16117542.1515, Train accuracy 20.18%\n",
      "Train loss 638806285.0275, Train accuracy 20.14%\n",
      "End of testing. Test accuracy 21.16%\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "net = nn.Sequential(nn.Linear(3 * 32 * 32, 1000), nn.ReLU(), nn.Linear(1000, 10))\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr = 0.1, momentum=0.9)\n",
    "net = net.to(device)\n",
    "\n",
    "for i in range(1):\n",
    "    train(net,device,train_loader,criterion,optimizer,i)\n",
    "    test(net,device,test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Train loss 0.0359, Train accuracy 9.94%\n",
      "Train loss 0.0337, Train accuracy 21.44%\n",
      "Train loss 0.0316, Train accuracy 26.47%\n",
      "Train loss 0.0303, Train accuracy 29.83%\n",
      "Train loss 0.0292, Train accuracy 32.75%\n",
      "Train loss 0.0285, Train accuracy 34.49%\n",
      "Train loss 0.0279, Train accuracy 35.83%\n",
      "Train loss 0.0275, Train accuracy 37.03%\n",
      "End of testing. Test accuracy 47.34%\n",
      "Epoch 1\n",
      "Train loss 0.0236, Train accuracy 45.60%\n",
      "Train loss 0.0234, Train accuracy 47.02%\n",
      "Train loss 0.0233, Train accuracy 47.32%\n",
      "Train loss 0.0231, Train accuracy 47.94%\n",
      "Train loss 0.0229, Train accuracy 48.40%\n",
      "Train loss 0.0227, Train accuracy 48.73%\n",
      "Train loss 0.0226, Train accuracy 48.99%\n",
      "Train loss 0.0225, Train accuracy 49.31%\n",
      "End of testing. Test accuracy 51.05%\n",
      "Epoch 2\n",
      "Train loss 0.0209, Train accuracy 53.12%\n",
      "Train loss 0.0210, Train accuracy 53.03%\n",
      "Train loss 0.0209, Train accuracy 53.20%\n",
      "Train loss 0.0208, Train accuracy 53.55%\n",
      "Train loss 0.0206, Train accuracy 53.88%\n",
      "Train loss 0.0205, Train accuracy 54.02%\n",
      "Train loss 0.0204, Train accuracy 54.11%\n",
      "Train loss 0.0203, Train accuracy 54.38%\n",
      "End of testing. Test accuracy 52.43%\n",
      "Epoch 3\n",
      "Train loss 0.0188, Train accuracy 58.52%\n",
      "Train loss 0.0191, Train accuracy 57.00%\n",
      "Train loss 0.0191, Train accuracy 57.26%\n",
      "Train loss 0.0190, Train accuracy 57.53%\n",
      "Train loss 0.0188, Train accuracy 57.78%\n",
      "Train loss 0.0187, Train accuracy 57.98%\n",
      "Train loss 0.0186, Train accuracy 58.20%\n",
      "Train loss 0.0185, Train accuracy 58.45%\n",
      "End of testing. Test accuracy 53.32%\n",
      "Epoch 4\n",
      "Train loss 0.0168, Train accuracy 62.93%\n",
      "Train loss 0.0174, Train accuracy 60.82%\n",
      "Train loss 0.0174, Train accuracy 60.86%\n",
      "Train loss 0.0173, Train accuracy 61.40%\n",
      "Train loss 0.0171, Train accuracy 61.74%\n",
      "Train loss 0.0170, Train accuracy 61.90%\n",
      "Train loss 0.0169, Train accuracy 62.16%\n",
      "Train loss 0.0168, Train accuracy 62.37%\n",
      "End of testing. Test accuracy 53.35%\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "net = nn.Sequential(\n",
    "    nn.Linear(3 * 32 * 32, 1000), \n",
    "    nn.ReLU(), \n",
    "    nn.Linear(1000, 1000),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(1000, 1000),\n",
    "    nn.ReLU(), \n",
    "    nn.Linear(1000, 10))\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr = 0.01, momentum=0.9)\n",
    "net = net.to(device)\n",
    "\n",
    "for i in range(5):\n",
    "    train(net,device,train_loader,criterion,optimizer,i)\n",
    "    test(net,device,test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Train loss 0.0150, Train accuracy 66.34%\n",
      "Train loss 0.0156, Train accuracy 65.53%\n",
      "Train loss 0.0157, Train accuracy 65.02%\n",
      "Train loss 0.0156, Train accuracy 65.37%\n",
      "Train loss 0.0154, Train accuracy 65.65%\n",
      "Train loss 0.0153, Train accuracy 65.75%\n",
      "Train loss 0.0152, Train accuracy 65.98%\n",
      "Train loss 0.0151, Train accuracy 66.20%\n",
      "Epoch 1\n",
      "Train loss 0.0131, Train accuracy 69.60%\n",
      "Train loss 0.0140, Train accuracy 68.85%\n",
      "Train loss 0.0141, Train accuracy 68.80%\n",
      "Train loss 0.0139, Train accuracy 69.13%\n",
      "Train loss 0.0137, Train accuracy 69.51%\n",
      "Train loss 0.0136, Train accuracy 69.79%\n",
      "Train loss 0.0135, Train accuracy 70.11%\n",
      "Train loss 0.0134, Train accuracy 70.40%\n",
      "Epoch 2\n",
      "Train loss 0.0113, Train accuracy 73.15%\n",
      "Train loss 0.0122, Train accuracy 72.26%\n",
      "Train loss 0.0124, Train accuracy 72.37%\n",
      "Train loss 0.0122, Train accuracy 72.72%\n",
      "Train loss 0.0121, Train accuracy 73.15%\n",
      "Train loss 0.0119, Train accuracy 73.38%\n",
      "Train loss 0.0119, Train accuracy 73.61%\n",
      "Train loss 0.0117, Train accuracy 73.91%\n",
      "Epoch 3\n",
      "Train loss 0.0095, Train accuracy 79.83%\n",
      "Train loss 0.0106, Train accuracy 76.18%\n",
      "Train loss 0.0108, Train accuracy 76.20%\n",
      "Train loss 0.0106, Train accuracy 76.60%\n",
      "Train loss 0.0105, Train accuracy 76.79%\n",
      "Train loss 0.0105, Train accuracy 76.89%\n",
      "Train loss 0.0104, Train accuracy 77.02%\n",
      "Train loss 0.0103, Train accuracy 77.22%\n",
      "Epoch 4\n",
      "Train loss 0.0089, Train accuracy 79.12%\n",
      "Train loss 0.0096, Train accuracy 78.24%\n",
      "Train loss 0.0097, Train accuracy 78.30%\n",
      "Train loss 0.0097, Train accuracy 78.44%\n",
      "Train loss 0.0095, Train accuracy 78.74%\n",
      "Train loss 0.0094, Train accuracy 79.02%\n",
      "Train loss 0.0094, Train accuracy 78.99%\n",
      "Train loss 0.0094, Train accuracy 78.93%\n",
      "Epoch 5\n",
      "Train loss 0.0082, Train accuracy 80.54%\n",
      "Train loss 0.0090, Train accuracy 79.60%\n",
      "Train loss 0.0090, Train accuracy 79.74%\n",
      "Train loss 0.0090, Train accuracy 79.72%\n",
      "Train loss 0.0088, Train accuracy 80.13%\n",
      "Train loss 0.0086, Train accuracy 80.55%\n",
      "Train loss 0.0086, Train accuracy 80.63%\n",
      "Train loss 0.0086, Train accuracy 80.65%\n",
      "Epoch 6\n",
      "Train loss 0.0073, Train accuracy 83.24%\n",
      "Train loss 0.0080, Train accuracy 81.78%\n",
      "Train loss 0.0079, Train accuracy 82.08%\n",
      "Train loss 0.0079, Train accuracy 82.01%\n",
      "Train loss 0.0078, Train accuracy 82.17%\n",
      "Train loss 0.0076, Train accuracy 82.67%\n",
      "Train loss 0.0076, Train accuracy 82.71%\n",
      "Train loss 0.0076, Train accuracy 82.69%\n",
      "Epoch 7\n",
      "Train loss 0.0077, Train accuracy 82.95%\n",
      "Train loss 0.0074, Train accuracy 83.69%\n",
      "Train loss 0.0072, Train accuracy 83.64%\n",
      "Train loss 0.0073, Train accuracy 83.24%\n",
      "Train loss 0.0071, Train accuracy 83.65%\n",
      "Train loss 0.0070, Train accuracy 84.04%\n",
      "Train loss 0.0069, Train accuracy 84.21%\n",
      "Train loss 0.0069, Train accuracy 84.25%\n",
      "Epoch 8\n",
      "Train loss 0.0052, Train accuracy 87.93%\n",
      "Train loss 0.0063, Train accuracy 85.19%\n",
      "Train loss 0.0063, Train accuracy 85.51%\n",
      "Train loss 0.0062, Train accuracy 85.73%\n",
      "Train loss 0.0062, Train accuracy 85.91%\n",
      "Train loss 0.0061, Train accuracy 86.15%\n",
      "Train loss 0.0061, Train accuracy 86.18%\n",
      "Train loss 0.0061, Train accuracy 86.31%\n",
      "Epoch 9\n",
      "Train loss 0.0046, Train accuracy 90.06%\n",
      "Train loss 0.0055, Train accuracy 87.53%\n",
      "Train loss 0.0055, Train accuracy 87.45%\n",
      "Train loss 0.0054, Train accuracy 87.59%\n",
      "Train loss 0.0055, Train accuracy 87.49%\n",
      "Train loss 0.0054, Train accuracy 87.69%\n",
      "Train loss 0.0054, Train accuracy 87.60%\n",
      "Train loss 0.0054, Train accuracy 87.61%\n"
     ]
    }
   ],
   "source": [
    "for i in range(10):\n",
    "    train(net,device,train_loader,criterion,optimizer,i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Train loss 0.0045, Train accuracy 89.63%\n",
      "Train loss 0.0054, Train accuracy 87.74%\n",
      "Train loss 0.0051, Train accuracy 88.40%\n",
      "Train loss 0.0050, Train accuracy 88.77%\n",
      "Train loss 0.0049, Train accuracy 88.85%\n",
      "Train loss 0.0049, Train accuracy 88.92%\n",
      "Train loss 0.0049, Train accuracy 88.90%\n",
      "Train loss 0.0048, Train accuracy 89.03%\n",
      "Epoch 1\n",
      "Train loss 0.0035, Train accuracy 92.47%\n",
      "Train loss 0.0044, Train accuracy 90.47%\n",
      "Train loss 0.0044, Train accuracy 90.23%\n",
      "Train loss 0.0043, Train accuracy 90.42%\n",
      "Train loss 0.0043, Train accuracy 90.31%\n",
      "Train loss 0.0042, Train accuracy 90.52%\n",
      "Train loss 0.0042, Train accuracy 90.53%\n",
      "Train loss 0.0042, Train accuracy 90.69%\n",
      "Epoch 2\n",
      "Train loss 0.0038, Train accuracy 91.48%\n",
      "Train loss 0.0041, Train accuracy 90.71%\n",
      "Train loss 0.0040, Train accuracy 91.00%\n",
      "Train loss 0.0039, Train accuracy 91.21%\n",
      "Train loss 0.0039, Train accuracy 91.26%\n",
      "Train loss 0.0039, Train accuracy 91.32%\n",
      "Train loss 0.0038, Train accuracy 91.43%\n",
      "Train loss 0.0038, Train accuracy 91.52%\n",
      "Epoch 3\n",
      "Train loss 0.0032, Train accuracy 91.90%\n",
      "Train loss 0.0036, Train accuracy 91.96%\n",
      "Train loss 0.0036, Train accuracy 91.89%\n",
      "Train loss 0.0035, Train accuracy 92.04%\n",
      "Train loss 0.0037, Train accuracy 91.67%\n",
      "Train loss 0.0036, Train accuracy 91.91%\n",
      "Train loss 0.0035, Train accuracy 92.15%\n",
      "Train loss 0.0034, Train accuracy 92.34%\n",
      "Epoch 4\n",
      "Train loss 0.0026, Train accuracy 94.32%\n",
      "Train loss 0.0032, Train accuracy 93.22%\n",
      "Train loss 0.0031, Train accuracy 93.26%\n",
      "Train loss 0.0032, Train accuracy 93.12%\n",
      "Train loss 0.0033, Train accuracy 92.81%\n",
      "Train loss 0.0032, Train accuracy 92.98%\n",
      "Train loss 0.0031, Train accuracy 93.12%\n",
      "Train loss 0.0031, Train accuracy 93.23%\n",
      "Epoch 5\n",
      "Train loss 0.0025, Train accuracy 94.03%\n",
      "Train loss 0.0031, Train accuracy 93.20%\n",
      "Train loss 0.0030, Train accuracy 93.33%\n",
      "Train loss 0.0029, Train accuracy 93.68%\n",
      "Train loss 0.0030, Train accuracy 93.37%\n",
      "Train loss 0.0029, Train accuracy 93.48%\n",
      "Train loss 0.0029, Train accuracy 93.52%\n",
      "Train loss 0.0029, Train accuracy 93.60%\n",
      "Epoch 6\n",
      "Train loss 0.0032, Train accuracy 92.47%\n",
      "Train loss 0.0028, Train accuracy 93.81%\n",
      "Train loss 0.0028, Train accuracy 93.52%\n",
      "Train loss 0.0027, Train accuracy 94.00%\n",
      "Train loss 0.0027, Train accuracy 93.97%\n",
      "Train loss 0.0027, Train accuracy 94.10%\n",
      "Train loss 0.0027, Train accuracy 94.15%\n",
      "Train loss 0.0026, Train accuracy 94.31%\n",
      "Epoch 7\n",
      "Train loss 0.0025, Train accuracy 95.03%\n",
      "Train loss 0.0023, Train accuracy 95.03%\n",
      "Train loss 0.0025, Train accuracy 94.61%\n",
      "Train loss 0.0025, Train accuracy 94.70%\n",
      "Train loss 0.0025, Train accuracy 94.61%\n",
      "Train loss 0.0024, Train accuracy 94.74%\n",
      "Train loss 0.0024, Train accuracy 94.73%\n",
      "Train loss 0.0024, Train accuracy 94.81%\n",
      "Epoch 8\n",
      "Train loss 0.0019, Train accuracy 95.17%\n",
      "Train loss 0.0023, Train accuracy 94.62%\n",
      "Train loss 0.0022, Train accuracy 94.87%\n",
      "Train loss 0.0022, Train accuracy 94.87%\n",
      "Train loss 0.0022, Train accuracy 94.96%\n",
      "Train loss 0.0021, Train accuracy 95.10%\n",
      "Train loss 0.0022, Train accuracy 95.14%\n",
      "Train loss 0.0021, Train accuracy 95.19%\n",
      "Epoch 9\n",
      "Train loss 0.0028, Train accuracy 93.32%\n",
      "Train loss 0.0022, Train accuracy 95.14%\n",
      "Train loss 0.0022, Train accuracy 95.40%\n",
      "Train loss 0.0022, Train accuracy 95.22%\n",
      "Train loss 0.0022, Train accuracy 95.27%\n",
      "Train loss 0.0021, Train accuracy 95.28%\n",
      "Train loss 0.0021, Train accuracy 95.42%\n",
      "Train loss 0.0020, Train accuracy 95.49%\n"
     ]
    }
   ],
   "source": [
    "#optimizer = torch.optim.SGD(net.parameters(), lr = 0.001, momentum=0.9)\n",
    "for i in range(10):\n",
    "    train(net,device,train_loader,criterion,optimizer,i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Train loss 0.0021, Train accuracy 95.17%\n",
      "Train loss 0.0023, Train accuracy 95.05%\n",
      "Train loss 0.0022, Train accuracy 95.17%\n",
      "Train loss 0.0022, Train accuracy 95.30%\n",
      "Train loss 0.0021, Train accuracy 95.51%\n",
      "Train loss 0.0019, Train accuracy 95.82%\n",
      "Train loss 0.0018, Train accuracy 96.14%\n",
      "Train loss 0.0016, Train accuracy 96.48%\n",
      "Epoch 1\n",
      "Train loss 0.0005, Train accuracy 99.01%\n",
      "Train loss 0.0007, Train accuracy 98.90%\n",
      "Train loss 0.0007, Train accuracy 98.87%\n",
      "Train loss 0.0007, Train accuracy 98.82%\n",
      "Train loss 0.0007, Train accuracy 98.82%\n",
      "Train loss 0.0007, Train accuracy 98.87%\n",
      "Train loss 0.0006, Train accuracy 98.94%\n",
      "Train loss 0.0006, Train accuracy 99.02%\n",
      "Epoch 2\n",
      "Train loss 0.0003, Train accuracy 99.57%\n",
      "Train loss 0.0004, Train accuracy 99.47%\n",
      "Train loss 0.0004, Train accuracy 99.45%\n",
      "Train loss 0.0005, Train accuracy 99.40%\n",
      "Train loss 0.0005, Train accuracy 99.38%\n",
      "Train loss 0.0004, Train accuracy 99.40%\n",
      "Train loss 0.0004, Train accuracy 99.42%\n",
      "Train loss 0.0004, Train accuracy 99.46%\n",
      "Epoch 3\n",
      "Train loss 0.0003, Train accuracy 99.72%\n",
      "Train loss 0.0003, Train accuracy 99.63%\n",
      "Train loss 0.0003, Train accuracy 99.66%\n",
      "Train loss 0.0003, Train accuracy 99.62%\n",
      "Train loss 0.0003, Train accuracy 99.60%\n",
      "Train loss 0.0003, Train accuracy 99.62%\n",
      "Train loss 0.0003, Train accuracy 99.64%\n",
      "Train loss 0.0003, Train accuracy 99.66%\n",
      "Epoch 4\n",
      "Train loss 0.0002, Train accuracy 99.72%\n",
      "Train loss 0.0003, Train accuracy 99.75%\n",
      "Train loss 0.0003, Train accuracy 99.77%\n",
      "Train loss 0.0003, Train accuracy 99.72%\n",
      "Train loss 0.0003, Train accuracy 99.72%\n",
      "Train loss 0.0003, Train accuracy 99.73%\n",
      "Train loss 0.0003, Train accuracy 99.75%\n",
      "Train loss 0.0003, Train accuracy 99.76%\n"
     ]
    }
   ],
   "source": [
    "optimizer = torch.optim.SGD(net.parameters(), lr = 0.001, momentum=0.9)\n",
    "for i in range(5):\n",
    "    train(net,device,train_loader,criterion,optimizer,i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "End of testing. Test accuracy 57.61%\n"
     ]
    }
   ],
   "source": [
    "test(net,device,test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Autograd tips and tricks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pointers are everywhere!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter containing:\n",
      "tensor([[ 0.4992,  0.0703],\n",
      "        [ 0.1255, -0.3164]], requires_grad=True)\n",
      "Parameter containing:\n",
      "tensor([[ 0.4990,  0.0657],\n",
      "        [ 0.1253, -0.3211]], requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "net = nn.Linear(2, 2)\n",
    "w = net.weight\n",
    "print(w)\n",
    "\n",
    "x = torch.rand(1, 2)\n",
    "y = net(x).sum()\n",
    "y.backward()\n",
    "net.weight.data -= 0.01 * net.weight.grad # <--- What is this?\n",
    "print(w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.5725,  0.2102],\n",
      "        [-0.6022,  0.3308]], grad_fn=<CloneBackward0>)\n",
      "tensor([[-0.5725,  0.2102],\n",
      "        [-0.6022,  0.3308]], grad_fn=<CloneBackward0>)\n"
     ]
    }
   ],
   "source": [
    "net = nn.Linear(2, 2)\n",
    "w = net.weight.clone()\n",
    "print(w)\n",
    "\n",
    "x = torch.rand(1, 2)\n",
    "y = net(x).sum()\n",
    "y.backward()\n",
    "net.weight.data -= 0.01 * net.weight.grad # <--- What is this?\n",
    "print(w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sharing weights "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0.5027, 1.1948],\n",
      "        [0.5017, 1.1942]])\n",
      "tensor([[0.5027, 1.1948],\n",
      "        [0.5017, 1.1942]])\n"
     ]
    }
   ],
   "source": [
    "net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))\n",
    "net[0].weight = net[1].weight  # weight sharing\n",
    "\n",
    "x = torch.rand(1, 2)\n",
    "y = net(x).sum()\n",
    "y.backward()\n",
    "print(net[0].weight.grad)\n",
    "print(net[1].weight.grad)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dldiy",
   "language": "python",
   "name": "dldiy"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
